from database.connection import get_db_connection


# DDL executed by init_database(), in order. Every statement uses IF NOT EXISTS,
# so executing the whole sequence against an already-initialised database is a
# no-op.
SCHEMA_STATEMENTS = (
    # Sessions table.
    """
    CREATE TABLE IF NOT EXISTS sessions
    (
        id           TEXT PRIMARY KEY,
        created_at   TEXT NOT NULL,
        updated_at   TEXT NOT NULL,
        last_message TEXT NOT NULL
    )
    """,
    # Messages table; each row references its owning session.
    # NOTE: SQLite only enforces FOREIGN KEY clauses on connections that run
    # `PRAGMA foreign_keys = ON`.
    """
    CREATE TABLE IF NOT EXISTS messages
    (
        id         TEXT PRIMARY KEY,
        session_id TEXT    NOT NULL,
        content    TEXT    NOT NULL,
        is_user    BOOLEAN NOT NULL,
        created_at TEXT    NOT NULL,
        FOREIGN KEY (session_id) REFERENCES sessions (id)
    )
    """,
    # Prompt template table.
    # The earlier schema declared UNIQUE (scenario, is_default), which also
    # allowed only ONE non-default template per scenario (a second row with
    # is_default = 0 collided). The "one default per scenario" rule is now the
    # partial unique index below. CREATE TABLE IF NOT EXISTS does not alter an
    # existing table, so databases created with the old schema keep the old
    # constraint until they are migrated.
    """
    CREATE TABLE IF NOT EXISTS prompt_templates
    (
        id          TEXT PRIMARY KEY,
        name        TEXT    NOT NULL,
        scenario    TEXT    NOT NULL,
        content     TEXT    NOT NULL,
        description TEXT,
        is_default  BOOLEAN NOT NULL DEFAULT 0,
        variables   TEXT    NOT NULL,
        created_at  TEXT    NOT NULL,
        updated_at  TEXT    NOT NULL
    )
    """,
    # Indexes for lookups by session and by scenario.
    "CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages (session_id)",
    "CREATE INDEX IF NOT EXISTS idx_templates_scenario ON prompt_templates (scenario)",
    # At most one default template per scenario; non-default templates are
    # unrestricted. (Partial indexes require SQLite >= 3.8.0.)
    """
    CREATE UNIQUE INDEX IF NOT EXISTS idx_templates_default_per_scenario
        ON prompt_templates (scenario)
        WHERE is_default = 1
    """,
)


async def init_database():
    """Create the application's tables and indexes if they do not exist yet.

    Opens a connection with ``get_db_connection()``, executes every statement
    in ``SCHEMA_STATEMENTS`` in order, commits, and always closes the
    connection (even if a statement raises). Calling it repeatedly is
    harmless because every statement uses ``IF NOT EXISTS``.
    """
    conn = await get_db_connection()
    try:
        for statement in SCHEMA_STATEMENTS:
            await conn.execute(statement)
        await conn.commit()
    finally:
        await conn.close()


if __name__ == '__main__':
    # Script entry point: run the schema initialisation coroutine once on a
    # fresh event loop, then exit.
    import asyncio as _asyncio

    _asyncio.run(init_database())
